# -*- coding: utf-8 -*- 
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
import warnings
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")   # suppress all warnings for the rest of the run

def plot_train_test(history):
    """Plot training and validation curves in two side-by-side panels.

    The left panel shows the loss and the right panel shows the MSE metric.
    Values are read from ``history.history`` under the keys ``loss`` /
    ``val_loss`` and ``mse`` / ``val_mse``; the validation curve is drawn
    dashed and labelled ``Test``.  The figure is displayed with
    ``plt.show()`` and nothing is returned.
    """
    records = history.history
    # (subplot position, key in history.history, y-axis label)
    panels = (
        (1, 'loss', 'loss'),
        (2, 'mse', 'MSE'),
    )

    plt.figure(figsize=(12, 4))
    # wspace controls the horizontal spacing between the two panels.
    plt.subplots_adjust(left=0.125, bottom=None, right=0.9, top=None,
                        wspace=0.3, hspace=None)

    for position, key, y_label in panels:
        plt.subplot(1, 2, position)
        plt.plot(records[key], linewidth=3, label='Train')
        plt.plot(records['val_' + key], linewidth=3, linestyle='dashed',
                 label='Test')
        plt.xlabel('Epoch', fontsize=20)
        plt.ylabel(y_label, fontsize=20)
        plt.legend()

    plt.show()


def build_ann(inputs, units_list=(100, 200, 100, 50, 25, 10, 1),
              activation='relu',
              optimizer='adam',
              init='glorot_uniform',
              rate=0.2,
              output_activation=None):
    """Build and compile a multilayer perceptron from the given parameters.

    The network is a ``tf.keras.Sequential`` model made of an input layer,
    one ``Dense`` + ``Dropout`` pair per entry of ``units_list``, and a final
    single-unit ``Dense`` output layer.  It is compiled with MAE loss and
    MSE as an extra metric.

    Args:
        inputs: Number of input features (used as ``shape=(inputs,)``).
        units_list: Iterable with the number of units of each hidden layer.
            The default is an immutable tuple so it cannot be modified
            between calls.
        activation: Activation used by every hidden layer.
        optimizer: Optimizer passed to ``compile``.
        init: Kernel initializer used by every ``Dense`` layer.
        rate: Dropout rate applied after each hidden layer.
        output_activation: Activation of the output layer.  ``None`` (the
            default) reuses ``activation``, which keeps the behaviour of
            earlier versions of this function.  Pass ``'linear'`` for an
            unbounded regression output.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    # NOTE(review): with the default 'relu' on the output layer the model
    # can only predict values >= 0; confirm that suits the target, or pass
    # output_activation='linear'.
    # NOTE(review): the default units_list ends with a 1-unit hidden layer
    # followed by Dropout, directly before the 1-unit output layer; verify
    # that this bottleneck is intended.
    if output_activation is None:
        output_activation = activation

    ann = tf.keras.Sequential()
    ann.add(layers.Input(shape=(inputs,)))
    for units in units_list:
        # Hidden layer
        ann.add(layers.Dense(units=units,
                             activation=activation,
                             kernel_initializer=init))
        # Dropout layer
        ann.add(layers.Dropout(rate=rate))
    # Output layer
    ann.add(layers.Dense(units=1,
                         activation=output_activation,
                         kernel_initializer=init))

    ann.compile(loss='mae',
                optimizer=optimizer,
                metrics=['mse'])

    return ann

if __name__ == '__main__':
    # NOTE: unpickling can execute arbitrary code; only load this file if it
    # comes from a trusted source.  The context manager closes the file
    # handle as soon as loading finishes.
    with open(r'../附件/intermedia_data.pkl', 'rb') as data_file:
        data = pickle.load(data_file)

    # Split the data into rows with and without a B-V value.
    X = data.loc[~data['B-V'].isna()]
    X_without_BV = data.loc[data['B-V'].isna()]

    # Use B-V as the target y and remove it from the features.  drop()
    # returns a new frame here instead of mutating a filtered slice of
    # `data` in place.
    y = X['B-V']
    X = X.drop('B-V', axis=1)

    # No random_state is given, so the split differs from run to run.
    X_train, X_test, y_train, y_test = train_test_split(
                                            X, y, test_size=0.3)

    inputs_num = X.shape[1]
    ann = build_ann(inputs_num)
    model_path = r'../附件/ann_model.h5'
    # Save the model only when the validation MSE improves.
    checkpoint = ModelCheckpoint(filepath=model_path,
                                 monitor='val_mse',
                                 verbose=1,
                                 save_best_only=True, mode='min')
    callback_list = [checkpoint]

    history = ann.fit(
        X_train,
        y_train,
        batch_size=200,
        epochs=500,
        validation_data=(X_test, y_test),
        callbacks=callback_list,
    )

    plot_train_test(history)
